# -*-Python-*-

import mesh_tensorflow.transformer.transformer_layers

# Layer classes for `make_layer_stack` when it is called under the `encoder`
# Gin scope. The list has 2 entries: SelfAttention, then DenseReluDense.
# NOTE(review): presumably this list is one repeated "module" of the stack,
# repeated `make_layer_stack.num_layers` times -- confirm against
# make_layer_stack.
encoder/transformer.make_layer_stack.layers = [
    @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
    @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
]
# Layer classes for `make_layer_stack` when it is called under the `decoder`
# Gin scope. The list has 3 entries: SelfAttention, then
# TransparentEncDecAttention, then DenseReluDense.
# NOTE(review): presumably this list is one repeated "module" of the stack,
# repeated `make_layer_stack.num_layers` times -- confirm against
# make_layer_stack.
decoder/transformer.make_layer_stack.layers = [
    @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
    @mesh_tensorflow.transformer.transformer_layers.TransparentEncDecAttention,
    @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
]

# Constructor arguments for the TransparentEncDecAttention layer used in the
# decoder layer list.
# NOTE(review): the values below appear to mirror the stack layout configured
# elsewhere in this file and probably have to be kept in sync by hand if the
# layer lists or `make_layer_stack.num_layers` change -- confirm against the
# TransparentEncDecAttention constructor.
TransparentEncDecAttention.dropout_rate = 0.1
# Equals the number of entries in the encoder layer list (2).
TransparentEncDecAttention.layers_per_encoder_module = 2
# Equals the number of entries in the decoder layer list (3).
TransparentEncDecAttention.layers_per_decoder_module = 3
# Both module counts equal `make_layer_stack.num_layers` (6).
TransparentEncDecAttention.encoder_num_modules = 6
TransparentEncDecAttention.decoder_num_modules = 6

# Unscoped binding: no `encoder/` or `decoder/` prefix, so it is not restricted
# to one of those scopes.
# NOTE(review): this value matches TransparentEncDecAttention.encoder_num_modules
# and .decoder_num_modules (both 6); presumably they must stay equal -- confirm
# before changing one without the others.
make_layer_stack.num_layers = 6
